/*	Copyright 2007 - Xavier Baro (xbaro@cvc.uab.cat)

	This file is part of eapmlib.

    Eapmlib is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or any 
	later version.

    Eapmlib is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
#include "AdaBoost.h"
#include <math.h>
#include <vector>

/**********************************************************************
****************     CABWeakLearner Definition       ******************
**********************************************************************/
Evolutive::CABWeakLearner::CABWeakLearner(void)
{
	// This class requires OpenCV support: when the library is built without
	// USE_OPENCV, constructing it reports the problem through the
	// OPENCV_ERROR_CLASS macro instead of working silently.
#ifndef USE_OPENCV
	OPENCV_ERROR_CLASS("CABWeakLearner");
#endif
}

Evolutive::CABWeakLearner::~CABWeakLearner(void)
{
	// Nothing to release: this base interface owns no resources.
}

/*****************************************************************
****************     CAdaBoost Definition       ******************
*****************************************************************/
// Default constructor: Gentle AdaBoost with 100 boosting iterations, no
// validation set, and every buffer pointer set to NULL so the destructor and
// Initialize() can safely free/replace them.
Evolutive::CAdaBoost::CAdaBoost(void) : m_Version(AB_GENTLE), m_MaxIters(100), m_Labels(NULL), m_WeakLearner(NULL),
				    				m_Weights(NULL), m_ValidationPercentage(0), m_ValidationIdx(NULL), m_NumSamples(0),
									m_NumValidationSamples(0), m_NewClassLabels(NULL), m_SamplesMask(NULL),m_Ensemble(NULL)
{
	// Without OpenCV support the learner cannot operate; flag it at
	// construction time.
#ifndef USE_OPENCV
	OPENCV_ERROR_CLASS("CAdaBoost");
#endif
}

// Destructor: release every buffer owned by the learner. Applying delete or
// delete[] to a null pointer is a well-defined no-op in C++, so the previous
// per-pointer null checks were redundant and have been removed.
Evolutive::CAdaBoost::~CAdaBoost(void)
{	
	delete[] m_Labels;
	delete[] m_SamplesMask;
	delete[] m_NewClassLabels;
	delete[] m_Weights;
	delete[] m_ValidationIdx;
	delete m_Ensemble;
}

#ifdef USE_OPENCV
// Set how many boosting rounds LearnEnsemble() will run.
void Evolutive::CAdaBoost::SetMaxIters(int Iterations)
{
	m_MaxIters=Iterations;
}

// Select which AdaBoost variant (discrete/gentle) the learner will use.
void Evolutive::CAdaBoost::SetVersion(Evolutive::AB_VERSION NewVersion)
{
	m_Version=NewVersion;
}

void Evolutive::CAdaBoost::SetWeakLearner(Evolutive::CABWeakLearner *WeakLearner)
{
	// Point to the weak learner
	m_WeakLearner=WeakLearner;
}

// Define the fraction of the training data that is held out as a
// validation set (0 disables validation).
void Evolutive::CAdaBoost::SetValidationPercentage(double Fraction)
{
	m_ValidationPercentage=Fraction;
}

// Expose the current ensemble without transferring ownership: the caller
// must not delete the returned pointer.
Evolutive::CAdditiveClassEnsemble *Evolutive::CAdaBoost::GetEnsemblePtr(void)
{
	return m_Ensemble;
}

// Hand the ensemble over to the caller, transferring ownership: the
// internal pointer is cleared so this object no longer deletes it.
Evolutive::CAdditiveClassEnsemble *Evolutive::CAdaBoost::GetEnsemble(void)
{
	CAdditiveClassEnsemble *Released=m_Ensemble;
	m_Ensemble=NULL;
	return Released;
}

// Reset the ensemble to an empty state.
// The null check is required: m_Ensemble is NULL until Initialize() runs,
// and again after GetEnsemble() has released ownership to the caller, so an
// unconditional dereference crashed in those states.
void Evolutive::CAdaBoost::Reset(void)
{
	if(m_Ensemble)
		m_Ensemble->Reset();
}

// Prepare the learner for a boosting run: fetch the labels from the weak
// learner, build the (optional) validation set, initialize the sample
// weights uniformly over the training samples, and create an empty ensemble.
// Throws CEvolutiveLibException when no weak learner / labels are available
// or when the validation set would consume every sample.
void Evolutive::CAdaBoost::Initialize(void)
{
	double WeightVal;
	register int i;

	// Verify that a weak learner is available
	if(!m_WeakLearner)
		throw CEvolutiveLibException("Undefined Weak Learner",__FILE__,__LINE__,"Initialize");

	// Obtain the training set information
	SetLabels(m_WeakLearner->GetNumSamples(),m_WeakLearner->GetLabelsPtr());

	// Verify that the labels list is available
	if(!m_Labels)
		throw CEvolutiveLibException("Undefined training set labels",__FILE__,__LINE__,"Initialize");

	// Remove old memory allocations. Each pointer is reset to NULL right
	// after the delete: BuildValidationSet() also frees m_ValidationIdx, so
	// leaving a dangling pointer here caused a double delete when the
	// learner was initialized more than once.
	delete[] m_Weights;
	m_Weights=NULL;
	delete[] m_ValidationIdx;
	m_ValidationIdx=NULL;
	delete[] m_NewClassLabels;
	m_NewClassLabels=NULL;
	delete m_Ensemble;
	m_Ensemble=NULL;

	// Obtain the number of samples
	m_NumSamples=m_WeakLearner->GetNumSamples();

	// Create the validation set
	BuildValidationSet();

	// At least one sample must remain for training, otherwise the uniform
	// weight below would divide by zero.
	if(m_NumSamples<=m_NumValidationSamples)
		throw CEvolutiveLibException("No training samples left after building the validation set",__FILE__,__LINE__,"Initialize");

	// Allocate the weights memory
	m_Weights=new double[m_NumSamples];

	// Calculate the initial weight value: uniform over the training samples.
	WeightVal=1.0/(m_NumSamples-m_NumValidationSamples);

	// Initialize the weights
	for(i=0;i<m_NumSamples;i++)
	{
		m_Weights[i]=WeightVal;
	}

	// When a validation set is defined set the weights of validation samples to 0.
	for(i=0;i<m_NumValidationSamples;i++)
	{
		m_Weights[m_ValidationIdx[i]]=0;
	}

	// Allocate the memory to store the most recent classifier results
	m_NewClassLabels=new int[m_NumSamples];

	// Create an empty ensemble
	m_Ensemble=new CAdditiveClassEnsemble();
}

// Store a copy of the training labels, validating that every label belongs
// to the set {-1,+1}. Throws CEvolutiveLibException on any other value.
void Evolutive::CAdaBoost::SetLabels(int NumSamples,int *Labels)
{
	register int i;

	// Keep track of the training set size
	m_NumSamples=NumSamples;

	// Replace any previously stored labels with a fresh buffer
	if(m_Labels)
		delete[] m_Labels;
	m_Labels=new int[m_NumSamples];

	// Copy the labels, rejecting anything outside {-1,+1}
	for(i=0;i<m_NumSamples;i++)
	{
		int Label=Labels[i];

		if(Label!=1 && Label!=-1)
			throw CEvolutiveLibException("Invalid value for a training label.",__FILE__,__LINE__,"SetLabels");

		m_Labels[i]=Label;
	}
}

// Randomly select a class-stratified validation subset of the training data.
// On exit, m_SamplesMask[i] is true for training samples and false for
// validation samples, and m_ValidationIdx holds the indices of the
// m_NumValidationSamples validation samples.
//
// Fixes over the previous version:
//  - the per-class quota was computed as m_NumSamples*(m_NumSamples/NumPos)
//    (integer division, wrong quantities); it is now the class-proportional
//    share of the validation set;
//  - the selection loop never advanced Idx (the two `continue`s spun
//    forever on an already-used sample) and tested the stale variable `i`
//    (left at m_NumSamples by the counting loop) instead of Idx;
//  - the index-list loop incremented its output cursor for every sample,
//    overrunning m_ValidationIdx.
void Evolutive::CAdaBoost::BuildValidationSet(void)
{
	register int i,Idx;
	int NumPos=0,NumNeg=0;
	int NumPosSmp=0,NumNegSmp=0;
	double UseProb;

	// Release old memory (and forget the pointers so later cleanup code
	// cannot free them twice)
	if(m_SamplesMask)
	{
		delete[] m_SamplesMask;
		m_SamplesMask=NULL;
	}
	if(m_ValidationIdx)
	{
		delete[] m_ValidationIdx;
		m_ValidationIdx=NULL;
	}

	// Calculate the size of the validation set
	m_NumValidationSamples=static_cast<int>(floor(m_NumSamples*m_ValidationPercentage));

	// If no validation set is requested exit
	if(!m_NumValidationSamples)
	{
		m_SamplesMask=NULL;
		m_ValidationIdx=NULL;
		return;
	}

	// Allocate the mask memory
	m_SamplesMask=new bool[m_NumSamples];

	// Count the positive and negative samples and mark every sample as a
	// training sample to start with
	for(i=0;i<m_NumSamples;i++)
	{
		if(m_Labels[i]>0)
			NumPos++;
		else
			NumNeg++;

		m_SamplesMask[i]=true;
	}

	// Per-class validation quotas, keeping the class proportions of the
	// full training set. Since m_NumValidationSamples <= m_NumSamples both
	// quotas are guaranteed not to exceed their class populations.
	NumPosSmp=static_cast<int>(floor(m_NumValidationSamples*(static_cast<double>(NumPos)/m_NumSamples)));
	NumNegSmp=m_NumValidationSamples-NumPosSmp;

	// Walk circularly over the samples, picking each unused one with
	// probability 1/2 until both quotas are filled
	Idx=0;
	while(NumPosSmp>0 || NumNegSmp>0)
	{
		// If we arrive to the end of the samples go to the beginning again
		if(Idx>=m_NumSamples)
			Idx=0;

		// Skip samples already moved to the validation set
		if(!m_SamplesMask[Idx])
		{
			Idx++;
			continue;
		}

		// Get the probability of using this sample
		UseProb=RAND_VALUE();

		// Keep the sample in the training set this round
		if(UseProb<0.5)
		{
			Idx++;
			continue;
		}

		// If we still need validation samples of this sample's class, take it
		if(m_Labels[Idx]>0 && NumPosSmp>0)
		{
			m_SamplesMask[Idx]=false;
			NumPosSmp--;
		}
		if(m_Labels[Idx]<0 && NumNegSmp>0)
		{
			m_SamplesMask[Idx]=false;
			NumNegSmp--;
		}

		Idx++;
	}

	// Create the fast access to the validation samples

	// Allocate the memory
	m_ValidationIdx=new int[m_NumValidationSamples];

	// Create the index list (only masked-out samples enter the list)
	Idx=0;
	for(i=0;i<m_NumSamples;i++)
	{
		if(!m_SamplesMask[i])
		{
			m_ValidationIdx[Idx]=i;
			Idx++;
		}
	}
}

// Run the full boosting loop: initialize the learner, add one weak
// classifier per round for m_MaxIters rounds, and return the finished
// ensemble (ownership is transferred to the caller).
Evolutive::CAdditiveClassEnsemble* Evolutive::CAdaBoost::LearnEnsemble(void)
{
	register int Round;

	// Prepare weights, labels, validation set and an empty ensemble
	Initialize();

	// Grow the ensemble one weak classifier per boosting round
	for(Round=0;Round<m_MaxIters;Round++)
	{
		// Progress report
		cout << "Adaboost Iteration: " << Round << endl;

		// A single boosting step lives in its own method so learning can
		// also be driven one step at a time by the caller.
		AddClassifier();
	}

	// Release the ensemble to the caller
	return GetEnsemble();
}

// Perform one boosting round: obtain a weak classifier, measure its
// weighted error, compute its ensemble weight, re-weight the samples and
// add the classifier to the ensemble.
//
// Fixes over the previous version: the AB_DISCRETE branch multiplied every
// weight by 1 (a no-op) and never accumulated WSum, so the normalization
// step divided by zero and destroyed all the weights. Discrete AdaBoost now
// uses the standard alpha = 0.5*ln((1-err)/err) both as the classifier
// weight and in the exponential weight update; Gentle AdaBoost keeps its
// original behavior (unit classifier weight, w *= exp(-y*h)).
void Evolutive::CAdaBoost::AddClassifier(void)
{
	register int i=0;
	CClassifier *NewClassifier;
	double Weight;
	double Error;
	double WSum=0;
	double TrainError,TrainBER;
	double ValidationError,ValidationBER;

	// Find the new weak classifier
	NewClassifier=m_WeakLearner->GetWeakClassifier(m_Weights,m_SamplesMask,m_NewClassLabels);

	// Evaluate the weighted training error of the new classifier
	Error=0;
	for(i=0;i<m_NumSamples;i++)
	{
		// Verify that the returned label belongs to {-1,+1}
		if(m_NewClassLabels[i]!=-1 && m_NewClassLabels[i]!=1)
		{
			// The classifier returned an invalid class
			throw CEvolutiveLibException("The Weak Classifier returned an unexpected class value.",__FILE__,__LINE__,"AddClassifier");
		}

		// Only training samples contribute to the error (when a validation
		// mask exists, mask==true marks a training sample)
		if(!m_SamplesMask || m_SamplesMask[i])
		{
			if(m_NewClassLabels[i]!=m_Labels[i])
				Error+=m_Weights[i];
		}
	}

	// Set the classifier weight
	switch(m_Version)
	{
	case Evolutive::AB_DISCRETE:
		{
			// Discrete AdaBoost: alpha = 0.5*ln((1-err)/err). The error is
			// clamped away from 0 and 1 to keep the logarithm finite.
			double ClampedError=Error;
			if(ClampedError<1e-10)
				ClampedError=1e-10;
			if(ClampedError>1.0-1e-10)
				ClampedError=1.0-1e-10;
			Weight=0.5*log((1.0-ClampedError)/ClampedError);
		}
		break;
	case Evolutive::AB_GENTLE:
	default:
		// Gentle AdaBoost adds each regression-style classifier with unit weight
		Weight=1.0;
		break;
	}

	// Update the sample weights: w_i *= exp(-alpha * y_i * h(x_i)).
	// Validation samples (mask false) keep their weight, which is always 0.
	for(i=0;i<m_NumSamples;i++)
	{
		if(!m_SamplesMask || m_SamplesMask[i])
		{
			m_Weights[i]*=exp(-Weight*(double)(m_Labels[i]*m_NewClassLabels[i]));
		}

		// Add the weight to the normalization factor
		WSum+=m_Weights[i];
	}

	// Normalize the weights to a probability distribution (guarded against
	// a degenerate all-zero weight vector)
	if(WSum>0)
	{
		for(i=0;i<m_NumSamples;i++)
			m_Weights[i]/=WSum;
	}

	// Add the new classifier to the ensemble
	m_Ensemble->AddClassifier(NewClassifier,Weight);

	// Show the final error	
	GetError(TrainError,TrainBER,ValidationError,ValidationBER);

	cout << "New added classifier error: " << Error << endl;
	cout << "Ensemble Error: " << endl;
	cout << "   Train set: " << TrainError << "=> BER = " << TrainBER << endl;
	if(m_SamplesMask)
		cout << "   Validation set: " << ValidationError << "=> BER = " << ValidationBER << endl;
	else
		cout << "   Validation set: Not applicable" << endl;
}

// Compute the current ensemble's error rate and balanced error rate (BER)
// on the training partition and, when a validation set exists, on the
// validation partition. When no validation set is defined, ValidationError
// and ValidationBER are set to -1.
// Empty classes now contribute 0 to the corresponding term instead of
// producing a division by zero (NaN/Inf in the reported errors).
void Evolutive::CAdaBoost::GetError(double &TrainError,double &TrainBER,double &ValidationError,double &ValidationBER)
{
	int NumPosTrain=0,NumNegTrain=0;
	int NumPosValidation=0,NumNegValidation=0;
	int NumTrainPosMissClass=0,NumTrainNegMissClass=0;
	int NumValPosMissClass=0,NumValNegMissClass=0;
	
	// Evaluate the miss classified samples
	GetClassErrors((CClassifier*)m_Ensemble,NumPosTrain,NumNegTrain,NumTrainPosMissClass,NumTrainNegMissClass,NumPosValidation,NumNegValidation,NumValPosMissClass,NumValNegMissClass);
	
	// Calculate the training errors
	TrainError=0;
	if(NumPosTrain+NumNegTrain>0)
		TrainError=static_cast<double>(NumTrainPosMissClass+NumTrainNegMissClass)/static_cast<double>(NumPosTrain+NumNegTrain);
	TrainBER=0;
	if(NumPosTrain>0)
		TrainBER+=static_cast<double>(NumTrainPosMissClass)/static_cast<double>(NumPosTrain);
	if(NumNegTrain>0)
		TrainBER+=static_cast<double>(NumTrainNegMissClass)/static_cast<double>(NumNegTrain);
	TrainBER/=2.0;
	
	// Calculate validation errors
	if(m_SamplesMask)
	{
		ValidationError=0;
		if(NumPosValidation+NumNegValidation>0)
			ValidationError=static_cast<double>(NumValPosMissClass+NumValNegMissClass)/static_cast<double>(NumPosValidation+NumNegValidation);
		ValidationBER=0;
		if(NumPosValidation>0)
			ValidationBER+=static_cast<double>(NumValPosMissClass)/static_cast<double>(NumPosValidation);
		if(NumNegValidation>0)
			ValidationBER+=static_cast<double>(NumValNegMissClass)/static_cast<double>(NumNegValidation);
		ValidationBER/=2.0;
	}
	else
	{
		// No validation set: report sentinel values
		ValidationError=-1;
		ValidationBER=-1;
	}
}

// Classify every sample with the given classifier and count, per class
// (positive/negative) and per partition (training/validation), the number
// of samples and the number of misclassifications. All eight output
// parameters are zero-initialized before counting.
// The classification buffer is now a std::vector: the previous raw
// new[]/delete[] pair leaked the buffer if Classify() threw, and
// GetNumSamples() is now called once instead of on every loop iteration.
void Evolutive::CAdaBoost::GetClassErrors(CClassifier *Classifier,int &TrainNumPos,int &TrainNumNeg,int &TrainPosMissSmp,int &TrainNegMissSmp,int &ValNumPos,int &ValNumNeg,int &ValPosMissSmp,int &ValNegMissSmp)
{
	register int i;
	bool PosNeg;
	bool Correct;
	const int NumSamples=m_WeakLearner->GetNumSamples();

	// Initialize the output variables
	TrainNumPos=0;
	TrainNumNeg=0;
	TrainPosMissSmp=0;
	TrainNegMissSmp=0;
	ValNumPos=0;
	ValNumNeg=0;
	ValPosMissSmp=0;
	ValNegMissSmp=0;

	// Nothing to count on an empty sample set
	if(NumSamples<=0)
		return;

	// Buffer for the final classification of every sample (exception-safe:
	// released automatically even if Classify throws)
	std::vector<int> NewClasses(NumSamples);

	// Classify all the samples using the given classifier
	m_WeakLearner->Classify(Classifier,&NewClasses[0]);

	// Evaluate the errors
	for(i=0;i<NumSamples;i++)
	{
		// Check if the sample is positive or negative
		PosNeg=(m_Labels[i]>0);

		// Check if the sample is correctly classified (label and prediction
		// share the same sign)
		Correct=(m_Labels[i]*NewClasses[i]>0);

		// Check the existence of a validation set
		if(m_SamplesMask)
		{
			if(PosNeg)
			{
				// Positive sample
				if(m_SamplesMask[i])
				{
					// Training set
					TrainNumPos++;
					if(!Correct)
						TrainPosMissSmp++;
				}
				else
				{
					// Validation set
					ValNumPos++;
					if(!Correct)
						ValPosMissSmp++;
				}
			}
			else
			{
				// Negative sample
				if(m_SamplesMask[i])
				{
					// Training set
					TrainNumNeg++;
					if(!Correct)
						TrainNegMissSmp++;
				}
				else
				{
					// Validation set
					ValNumNeg++;
					if(!Correct)
						ValNegMissSmp++;
				}
			}
		}
		else
		{
			// There is no validation set, all samples belong to the training set
			if(PosNeg)
			{
				TrainNumPos++;
				if(!Correct)
					TrainPosMissSmp++;
			}
			else
			{
				TrainNumNeg++;
				if(!Correct)
					TrainNegMissSmp++;
			}
		}
	}
}
#endif //USE_OPENCV